from matplotlib import pyplot as plt
import matplotlib.patches as mpatches

import numpy as np

# ============================================
# AFFICHAGE DE LA TABLE
# ============================================
def AfficheTable(L, g):
    """
    Affiche la table de programmation dynamique pour Floyd-Warshall.
    Pour chaque k, affiche une matrice n×n des distances.
    Maximum 3 matrices par ligne.
    """
    sommets = list(g.keys())
    n = len(sommets)

    # Trouver le k max dans L
    k_max = max(key[0] for key in L.keys()) if L else 0

    # Calculer le nombre de lignes et colonnes (max 3 colonnes)
    nb_matrices = k_max + 1
    nb_cols = min(3, nb_matrices)
    nb_rows = (nb_matrices + nb_cols - 1) // nb_cols  # Arrondi supérieur

    # Créer une figure avec une grille de sous-graphiques
    fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(3.5 * nb_cols, 3.5 * nb_rows))

    # Convertir axes en tableau 2D si nécessaire
    if nb_matrices == 1:
        axes = np.array([[axes]])
    elif nb_rows == 1:
        axes = np.array([axes])
    elif nb_cols == 1:
        axes = axes.reshape(-1, 1)

    for k in range(k_max + 1):
        row = k // nb_cols
        col = k % nb_cols
        ax = axes[row, col]

        # Créer une matrice RGB pour ce k
        mat = np.zeros((n, n, 3))

        for i, v in enumerate(sommets):
            for j, w in enumerate(sommets):
                if (k, v, w) in L:
                    mat[i][j] = [0.2, 0.7, 0.3]  # Vert
                else:
                    mat[i][j] = [0.85, 0.85, 0.85]  # Gris clair

        ax.imshow(mat)

        # Afficher les valeurs dans chaque case
        for i, v in enumerate(sommets):
            for j, w in enumerate(sommets):
                if (k, v, w) in L:
                    valeur = L[(k, v, w)]
                    if valeur == float('inf'):
                        txt = '∞'
                    else:
                        txt = str(int(valeur))
                    ax.text(j, i, txt, ha='center', va='center',
                            color='white', fontsize=8, fontweight='bold')

        # Configurer les axes
        ax.set_xticks(range(n))
        ax.set_xticklabels(sommets, fontsize=8)
        ax.set_yticks(range(n))
        ax.set_yticklabels(sommets, fontsize=8)
        ax.set_xlabel('Destination (w)', fontsize=8)
        ax.set_ylabel('Origine (v)', fontsize=8)
        ax.set_title(f'k = {k}', fontsize=10)

        # Quadrillage
        ax.set_xticks(np.arange(-0.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, n, 1), minor=True)
        ax.grid(which='minor', color='black', linewidth=0.5)

    # Masquer les axes inutilisés (si nb_matrices n'est pas multiple de nb_cols)
    for k in range(nb_matrices, nb_rows * nb_cols):
        row = k // nb_cols
        col = k % nb_cols
        axes[row, col].axis('off')

    plt.suptitle('Table de programmation dynamique Floyd-Warshall', fontsize=12)
    plt.tight_layout()
    plt.show()


# Graphe représenté par un dictionnaire d'adjacence
# graphe[v] = [(w1, poids1), (w2, poids2), ...]
graphe = {
    1: [(2, 2), (3, 4)],
    2: [(3, -1), (4, 2)],
    3: [(4, 3), (5, 4)],
    4: [(5, 2)],
    5: []
}


# Variante avec cycle négatif (pour tests)
graphe_neg = {
    1: [(2, 4), (3, 2)],
    2: [(4, 3), (5, 4)],
    3: [(2, -1), (4, 2), (5,4)],
    4: [(2, -5), (5, 2)],
    5: []
}


############################
# Approche bottom-up
############################

def poids_arete(G, u, v):
    poids = np.inf
    for sommet, pds in G[u]:
        if sommet == v:
            return pds
    return poids

def initialiser_L0(G):
    L = {}
    for source in G:
        for dest in G:
            # L[(0,v,v)]
            if source == dest:
                L[(0,source,dest)] = 0
            # L[(0,v,w)]
            else:
                L[(0,source,dest)] = poids_arete(G,source,dest)
    return L

# Questions théoriques:
# (a) : pour un graphe à n sommets, il y a (n+1)*n*n problemes calculés
# (b) : complexité temporelle : coût du travail en O(1) donc au total O(n^3)
# (c) : L(v,w) = min (L(v,w) , L(v,k) + L(k,w)) à chaque itération

def floyd_warshall_bottomup(G,L):
    n = len(G)
    dist = {}

    # Cas de la récurrence avec test des cycles négatifs
    for k in range(1,n+1):
        for v in G:
            for w in G:
                L[(k,v,w)] = min (L[(k-1,v,w)], L[(k-1,v,k)] + L[(k-1,k,w)])
                if v == w and L[(k,v,w)] < 0:
                    return (None, True)

    # Construction du dictionnaire des distances
    dist = {}
    for v in G:
        for w in G:
            dist[(v,w)] = L[(n,v,w)]
    return (dist, False)

L = initialiser_L0(graphe)
dist, cycle_negatif = floyd_warshall_bottomup(graphe,L)
AfficheTable(L,graphe)

L = initialiser_L0(graphe_neg)
dist, cycle_negatif = floyd_warshall_bottomup(graphe_neg,L)
AfficheTable(L,graphe_neg)


##################################
# Implémentation top-down
##################################

def floyd_warshall_topdown_paire(G,v,w):
    L = {}

    def f_rec(k,a,b):
        # Mémoîsation
        if (k,a,b) in L:
            return L[(k,a,b)]

        # Cas de base k == 0
        if k == 0:
            if a == b:             # k=0 et a == b
                L[(k,a,b)] = 0     # L[(0,x,x)] = 0
            else:
                L[(k,a,b)] = poids_arete(G,a,b)     # L[(0,a,b] = l(a,b) | inf
            return L[(k,a,b)]


        # Récurrence
        S1 = f_rec(k-1,a,b)
        S2 = f_rec(k-1,a,k) + f_rec(k-1,k,b)
        L[(k,a,b)] = min(S1,S2)

        # Détection cycle_negatif
        if a == b and L[(k,a,b)] < 0:
            raise ValueError("Cycle négatif détecté")
        return L[(k,a,b)]

    n = len(G)
    return L, f_rec(n,v,w)

# Questions théoriques
# (a) : Même complexité que bottom-up dans le pire des cas O(n^3)
# (b) : Tous les états ne sont pas nécessaire pour terminer la récurrence
# (c) : Un cycle négatif peut apparaître à k = n et pas avant
def floyd_warshall_topdown_toutes_paires(G):
    L = {}

    def f_rec(k,a,b):
        # Mémoîsation
        if (k,a,b) in L:
            return L[(k,a,b)]

        # Cas de base k == 0
        if k == 0:
            if a == b:             # k=0 et a == b
                L[(k,a,b)] = 0     # L[(0,x,x)] = 0
            else:
                L[(k,a,b)] = poids_arete(G,a,b)     # L[(0,a,b] = l(a,b) | inf
            return L[(k,a,b)]


        # Récurrence
        S1 = f_rec(k-1,a,b)
        S2 = f_rec(k-1,a,k) + f_rec(k-1,k,b)
        L[(k,a,b)] = min(S1,S2)

        # Détection cycle_negatif
        if a == b and L[(k,a,b)] < 0:
            raise ValueError("Cycle négatif détecté")
        return L[(k,a,b)]

    n = len(G)
    dist = {}
    try:
        # Calcul des distances par récurrence top-down
        for v in G:
            for w in G:
                dist[(v,w)] = f_rec(n,v,w)
        return (dist,False,L)
    except ValueError as e:
        if str(e) == "Cycle négatif détecté":
            return (None,True,None)
        raise



L, dist = floyd_warshall_topdown_paire(graphe,2,4)
AfficheTable(L,graphe)

dist,cycle_negatif,L = floyd_warshall_topdown_toutes_paires(graphe)
AfficheTable(L,graphe)


###############################
# Reconstruction
###############################
def decision_reconstruction(L, k, v, w):
    if L[(k,v,w)] == L[(k-1,v,w)]:
        return "HERITER"
    elif L[(k,v,w)] == L[(k-1,v,k)] + L[(k-1,k,w)]:
        return "DECOMPOSER"
    else:
        raise ValueError("Incohérence dans la table")

# Question théorique
# (a) : Dans le pire des cas on fait une récursion de k=n à k=0 : O(n)
# (b) : n² chemins de longueur O(n) => en O(n3).
def rec_chemin(L,k,v,w):
    if L[(k,v,w)] == np.inf:
        return []
    if k == 0:
        return [v,w]
    elif decision_reconstruction(L,k,v,w) == "HERITER":
        return rec_chemin(L,k-1,v,w)
    else:
        chemin1 = rec_chemin(L,k-1,v,k)
        chemin2 = rec_chemin(L,k-1,k,w)
        return chemin1 + chemin2[1:]

dist,cycle_negatif,L = floyd_warshall_topdown_toutes_paires(graphe)
print(rec_chemin(L,5,1,5))











































